"""
This file contains networks that are capable of handling (time, batch, [applicable features]) inputs.
"""
import copy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from continual_rl.utils.common_nets import get_network_for_size, CommonConv
from continual_rl.utils.utils import Utils


class ImpalaNet(nn.Module):
    """
    Based on Impala's AtariNet, taken from:
    https://github.com/facebookresearch/torchbeast/blob/6ed409587e8eb16d4b2b1d044bf28a502e5e3230/torchbeast/monobeast.py

    Extended to support both "Discrete" and continuous ("Box") action spaces. Inputs carry leading
    (time, batch) dimensions; see ``forward`` for the expected layout.
    """

    def __init__(self, observation_spaces, action_spaces, model_flags, conv_net=None):
        """
        Args:
            observation_spaces: Per-task observation spaces; the largest one (via Utils) sizes the encoder.
            action_spaces: Per-task action spaces, indexed by ``action_space_id`` in ``forward``.
            model_flags: Config object; reads use_lstm, conv_net_arch, baseline_includes_uncertainty and
                baseline_extended_arch.
            conv_net: Optional pre-built observation encoder. When None, a default encoder is built.

        Raises:
            ValueError: If no ``conv_net`` is given and the observation shape is neither 2-D nor at least 4-D.
        """
        super().__init__()
        self.use_lstm = model_flags.use_lstm
        conv_net_arch = model_flags.conv_net_arch

        max_action_space = Utils.get_max_action_space(action_spaces)
        self._observation_space = Utils.get_max_observation_space(observation_spaces)

        # Discrete vs. continuous action type. For "Box" spaces, num_actions is the action vector's length.
        if max_action_space.__class__.__name__ == "Discrete":
            self.action_type = "Discrete"
            self.num_actions = max_action_space.n
        else:
            self.action_type = "Box"
            self.num_actions = max_action_space.shape[0]

        self._model_flags = model_flags
        # Per-task action spaces. The policy head always outputs num_actions values (sized from the largest space).
        self._action_spaces = action_spaces
        self._current_action_size = None  # Set by the environment_runner
        self._observation_spaces = observation_spaces

        if conv_net is None:
            self._conv_net = self._build_default_encoder(conv_net_arch)
        else:
            self._conv_net = conv_net

        if isinstance(self._conv_net, CommonConv):
            encoder_output_size = self._conv_net.output_size
        else:
            # The default MLP encoder emits 64 features; custom encoders may advertise their own output_size.
            encoder_output_size = getattr(self._conv_net, "output_size", 64)

        # Encoder features + last reward + last action (one-hot for Discrete, raw action vector for Box).
        core_output_size = encoder_output_size + self.num_actions + 1

        if self.action_type == "Box":
            # Continuous actions: a Gaussian policy with a learned, state-independent log standard deviation.
            self.policy_mean = nn.Linear(core_output_size, self.num_actions)
            self.policy_log_std = nn.Parameter(torch.zeros(self.num_actions))
        else:
            self.policy = nn.Linear(core_output_size, self.num_actions)

        self._baseline_output_dim = 2 if model_flags.baseline_includes_uncertainty else 1

        # The first output value is the standard critic value. The second is an optional value the policies may use
        # which we call "uncertainty".
        if model_flags.baseline_extended_arch:
            self.baseline = nn.Sequential(
                nn.Linear(core_output_size, 32),
                nn.ReLU(),
                nn.Linear(32, 32),
                nn.ReLU(),
                nn.Linear(32, self._baseline_output_dim)
            )
        else:
            self.baseline = nn.Linear(core_output_size, self._baseline_output_dim)

        # Running reward statistics used by update_running_moments(); reward_m2 is the sum of squared deviations.
        # reward_count starts at a tiny epsilon so the first mean computation does not divide by zero.
        self.register_buffer("reward_sum", torch.zeros(()))
        self.register_buffer("reward_m2", torch.zeros(()))
        self.register_buffer("reward_count", torch.zeros(()).fill_(1e-8))

    def _build_default_encoder(self, conv_net_arch):
        """Build the observation encoder used when no ``conv_net`` is passed to the constructor."""
        obs_shape = self._observation_space.shape
        if len(obs_shape) >= 4:
            # Image-like observations, presumably [S, C, H, W]. The conv net gets stacked frames and channels merged
            # together (mimicking the original FrameStacking).
            combined_observation_size = [obs_shape[0] * obs_shape[1], obs_shape[2], obs_shape[3]]
            return get_network_for_size(combined_observation_size, arch=conv_net_arch)
        if len(obs_shape) == 2:
            # Non-image observations [S, D]: stacked frames and features are flattened into one input vector.
            return nn.Sequential(
                nn.Linear(obs_shape[0] * obs_shape[1], 64),
                nn.ReLU(),
                nn.Linear(64, 64),
                nn.ReLU(),
            )
        raise ValueError(
            f"Unsupported observation shape {tuple(obs_shape)}: expected 2 dims [S, D] or at least 4 dims [S, C, H, W]."
        )

    def _encode_last_action(self, last_action, num_rows):
        """Encode the previous action as a float tensor of shape [num_rows, num_actions] for the core input."""
        if self.action_type == "Box":
            # Continuous actions are already vectors of length num_actions; one-hot encoding does not apply.
            return last_action.reshape(num_rows, self.num_actions).float()
        return F.one_hot(last_action.view(num_rows), self.num_actions).float()

    def initial_state(self, batch_size):
        """Return the initial recurrent state (empty, since the LSTM core is not implemented)."""
        assert not self.use_lstm, "LSTM not currently implemented. Ensure this gets initialized correctly when it is " \
                                  "implemented."
        return tuple()

    def forward(self, inputs, action_space_id, core_state=()):
        """
        Run the network on a (time, batch) block of observations.

        Args:
            inputs: Dict with "frame" ([T, B, S, C, H, W] for images, presumably [T, B, S, D] otherwise),
                "last_action" (presumably indices [T, B] for Discrete, vectors [T, B, num_actions] for Box),
                "reward" (T * B values) and, when use_lstm is set, a boolean "done" tensor.
            action_space_id: Index into the per-task action spaces. For Discrete actions, only the first ``n``
                logits of that space are used to select the action.
            core_state: Recurrent state; only used when use_lstm is set.

        Returns:
            (output_dict, core_state). output_dict holds "policy_logits" [T, B, num_actions] (the Gaussian mean for
            Box), "baseline" [T, B], "action" ([T, B] for Discrete, [T, B, num_actions] for Box) and, when
            model_flags.baseline_includes_uncertainty is set, "uncertainty" [T, B].
        """
        x = inputs["frame"]  # [T, B, S, C, H, W]. T=timesteps in collection, S=stacked frames
        T, B, *_ = x.shape
        x = torch.flatten(x, 0, 1)  # Merge time and batch.
        x = torch.flatten(x, 1, 2)  # Merge stacked frames and channels.

        obs_high = self._observation_space.high.max()
        if obs_high != np.inf:
            x = x.float() / obs_high
        else:
            # Environments with an unbounded observation space cannot be normalized by their upper bound.
            x = x.float()
        x = F.relu(self._conv_net(x))

        last_action = self._encode_last_action(inputs["last_action"], T * B)
        clipped_reward = torch.clamp(inputs["reward"], -1, 1).view(T * B, 1).float()
        core_input = torch.cat([x, clipped_reward, last_action], dim=-1)

        if self.use_lstm:
            # NOTE(review): self.core is not created in this class; a subclass must define it for this path to work.
            core_input = core_input.view(T, B, -1)
            core_output_list = []
            notdone = (~inputs["done"]).float()
            for step_input, nd in zip(core_input.unbind(), notdone.unbind()):
                # Reset core state to zero whenever an episode ended.
                # Make `done` broadcastable with (num_layers, B, hidden_size)
                # states:
                nd = nd.view(1, -1, 1)
                core_state = tuple(nd * s for s in core_state)
                output, core_state = self.core(step_input.unsqueeze(0), core_state)
                core_output_list.append(output)
            core_output = torch.flatten(torch.cat(core_output_list), 0, 1)
        else:
            core_output = core_input
            core_state = tuple()

        baseline = self.baseline(core_output)

        if self.action_type == "Box":
            # Continuous action space. Sampled actions are not clipped to the action space bounds here.
            mean = self.policy_mean(core_output)
            std = torch.exp(self.policy_log_std.expand_as(mean))

            if self.training:
                # Gaussian sample via the reparameterization trick.
                action = torch.distributions.Normal(mean, std).rsample()
            else:
                # At test time, act deterministically with the mean.
                action = mean

            action = action.view(T, B, self.num_actions)
            policy_logits = mean
        else:
            # Discrete action space.
            policy_logits = self.policy(core_output)

            # Used to select the action appropriate for this task (might be from a reduced set):
            # only the logits within the current task's action space size are considered.
            current_action_size = self._action_spaces[action_space_id].n
            if current_action_size < policy_logits.shape[-1]:
                policy_logits_subset = policy_logits[:, :current_action_size]
            else:
                policy_logits_subset = policy_logits

            if self.training:
                action = torch.multinomial(F.softmax(policy_logits_subset, dim=1), num_samples=1)
            else:
                # Don't sample when testing.
                action = torch.argmax(policy_logits_subset, dim=1)
            action = action.view(T, B)

        policy_logits = policy_logits.view(T, B, -1)
        baseline = baseline.view(T, B, self._baseline_output_dim)

        output_dict = dict(policy_logits=policy_logits, baseline=baseline[:, :, 0], action=action)

        if self._model_flags.baseline_includes_uncertainty:
            output_dict["uncertainty"] = baseline[:, :, 1]

        return output_dict, core_state

    # from https://github.com/MiniHackPlanet/MiniHack/blob/e124ae4c98936d0c0b3135bf5f202039d9074508/minihack/agent/polybeast/models/base.py#L67
    @torch.no_grad()
    def update_running_moments(self, reward_batch):
        """
        Fold a batch of rewards into the running reward statistics (count, sum, sum of squared deviations).

        Uses the parallel (Chan et al.) combination for the sum of squared deviations, so get_running_std()
        returns the population standard deviation of every reward seen so far.

        Args:
            reward_batch: Tensor of rewards of any shape; every element counts as one sample.
        """
        # Count every element: the sums below run over all elements, not just the first dimension.
        new_count = reward_batch.numel()
        if new_count == 0:
            # An empty batch would compute 0/0 = NaN and permanently poison the running statistics.
            return

        new_sum = torch.sum(reward_batch)
        new_mean = new_sum / new_count

        curr_mean = self.reward_sum / self.reward_count
        new_m2 = torch.sum((reward_batch - new_mean) ** 2) + (
                (self.reward_count * new_count)
                / (self.reward_count + new_count)
                * (new_mean - curr_mean) ** 2
        )

        self.reward_count += new_count
        self.reward_sum += new_sum
        self.reward_m2 += new_m2

    @torch.no_grad()
    def get_running_std(self):
        """Return the population standard deviation of all rewards folded in by update_running_moments()."""
        return torch.sqrt(self.reward_m2 / self.reward_count)
